Fix #8239: Enhance SoftclDiceLoss and SoftDiceclDiceLoss with DiceLoss-compatible API#8703
Fix #8239: Enhance SoftclDiceLoss and SoftDiceclDiceLoss with DiceLoss-compatible API#8703aymuos15 wants to merge 6 commits intoProject-MONAI:devfrom
Conversation
… with additional parameters - Add include_background, to_onehot_y, sigmoid, softmax, other_act, and reduction parameters - Fix argument order in forward() to match other losses (y_pred, y_true) - Add proper input validation and comprehensive docstrings - These changes make the losses consistent with DiceLoss API and fix zero loss issues Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughReplaces the prior minimal soft_dice export with a new SoftclDiceLoss class and a combined SoftDiceclDiceLoss that blends Dice and clDice via an alpha weight. Both classes accept configurable activations (sigmoid/softmax/custom), optional to_onehot target conversion, include_background toggle, smoothing, iterative clDice parameters, and LossReduction-based reductions. Constructors validate arguments and deprecation aliases; forward methods validate shapes/channels, apply activation and optional one-hot conversion, delegate core computation to Dice/clDice logic, and raise on invalid configurations. Tests were rewritten and parameterized to exercise activations, reductions, shape/channel mismatches, CUDA, and edge cases. Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches
🧪 Generate unit tests (beta)
Tip Issue Planner is now in beta. Read the docs and try it out! Share your feedback on Discord. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/losses/cldice.py`:
- Around line 217-223: The computation of tprec/tsens and cl_dice can produce
NaN when self.smooth == 0 and the denominator sums (torch.sum(skel_pred, ...) or
torch.sum(skel_true, ...)) are zero; update the logic in the CLDice computation
(references: skel_pred, skel_true, input, tprec, tsens, cl_dice, self.smooth,
reduce_axis) to guard denominators by a small positive epsilon (or enforce
self.smooth > 0) — e.g., compute denom_pred = torch.sum(skel_pred,
dim=reduce_axis).clamp_min(eps) (and similarly for denom_true) or use
torch.where to replace zero denominators with eps before dividing, so
tprec/tsens and cl_dice never become NaN; optionally add a docstring note that
smooth must be positive.
🧹 Nitpick comments (4)
monai/losses/cldice.py (2)
187-187: Addstacklevel=2to warnings.Without stacklevel, warnings point to this line instead of the caller's location.
Proposed fix
- warnings.warn("single channel prediction, `softmax=True` ignored.") + warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)Same applies to lines 196 and 202.
340-340: Addstacklevel=2here as well.Proposed fix
- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)tests/losses/test_cldice_loss.py (2)
36-71: Consider adding a test case forinclude_background=False.Current cases cover sigmoid, softmax, and to_onehot_y, but not background exclusion.
82-87: Consider using@unittest.skipUnlessdecorator.More idiomatic than early return.
Proposed fix
+ `@unittest.skipUnless`(torch.cuda.is_available(), "CUDA not available") def test_cuda(self): - if not torch.cuda.is_available(): - return loss = SoftclDiceLoss() result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda()) np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
monai/losses/cldice.pytests/losses/test_cldice_loss.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/losses/test_cldice_loss.pymonai/losses/cldice.py
🧬 Code graph analysis (1)
monai/losses/cldice.py (3)
monai/losses/dice.py (1)
DiceLoss(31-229)monai/networks/utils.py (1)
one_hot(170-220)monai/utils/enums.py (1)
LossReduction(253-264)
🪛 Ruff (0.14.11)
monai/losses/cldice.py
158-158: Avoid specifying long messages outside the exception class
(TRY003)
160-160: Avoid specifying long messages outside the exception class
(TRY003)
187-187: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
196-196: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
202-202: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
208-208: Avoid specifying long messages outside the exception class
(TRY003)
233-233: Avoid specifying long messages outside the exception class
(TRY003)
326-329: Avoid specifying long messages outside the exception class
(TRY003)
332-335: Avoid specifying long messages outside the exception class
(TRY003)
340-340: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🔇 Additional comments (5)
monai/losses/cldice.py (4)
14-23: LGTM!Imports are appropriate for the enhanced functionality.
119-167: LGTM!Constructor properly validates activation options and aligns with DiceLoss API.
291-312: LGTM!Correctly centralizes one-hot conversion to avoid double application in composed losses.
314-348: LGTM!Input validation and combined loss computation are correct.
tests/losses/test_cldice_loss.py (1)
111-134: LGTM!Good coverage of combined loss functionality and error cases.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/losses/cldice.py`:
- Around line 161-163: The code currently validates smooth but not the
skeletonization iteration count iter_; add validation in the clDice constructor
(or the function where self.iter is set) to ensure iter_ is a positive integer:
check that iter_ is an int (or castable to int) and greater than 0, and raise a
ValueError with a clear message if not; then assign self.iter = int(iter_) so
downstream skeletonize/skel operations use a safe positive integer (refer to the
self.iter assignment and the iter_ parameter in the clDice class/constructor).
🧹 Nitpick comments (6)
monai/losses/cldice.py (6)
188-189: Addstacklevel=2to warnings.Without it, warnings point to this file instead of the caller's location.
Proposed fix
if n_pred_ch == 1: - warnings.warn("single channel prediction, `softmax=True` ignored.") + warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)Apply similarly to lines 198 and 204.
Also applies to: 197-198, 203-204
232-233: Inconsistentreduction="none"behavior with DiceLoss.DiceLoss applies
.view(-1)forreduction="none". Here it's left as-is. This may cause issues when combining losses or stacking results.Proposed fix for consistency
elif self.reduction == LossReduction.NONE.value: - pass # keep per-batch values + cl_dice = cl_dice.view(-1)
293-315: Consider validatingalpharange.
alphaoutside[0, 1]would produce unusual weighting. While possibly intentional, a validation or warning could prevent mistakes.Proposed fix
if smooth <= 0: raise ValueError(f"smooth must be a positive value but got {smooth}.") + if not 0.0 <= alpha <= 1.0: + warnings.warn(f"alpha={alpha} is outside [0, 1], loss weighting may be unusual.", stacklevel=2)
344-344: Addstacklevel=2here as well.- warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
171-181: MissingReturnssection in docstring.Per coding guidelines, return values should be documented.
Proposed addition
Raises: AssertionError: When input and target (after one hot transform if set) have different shapes. + + Returns: + torch.Tensor: The computed clDice loss, reduced according to `self.reduction`. """
318-328: MissingReturnssection in docstring.Same as
SoftclDiceLoss.forward.Proposed addition
Raises: ValueError: When number of dimensions for input and target are different. ValueError: When number of channels for target is neither 1 nor the same as input. + + Returns: + torch.Tensor: The weighted combination of Dice and clDice losses. """
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/losses/cldice.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/losses/cldice.py
🧬 Code graph analysis (1)
monai/losses/cldice.py (3)
monai/losses/dice.py (1)
DiceLoss(31-229)monai/networks/utils.py (1)
one_hot(170-220)monai/utils/enums.py (1)
LossReduction(253-264)
🪛 Ruff (0.14.11)
monai/losses/cldice.py
158-158: Avoid specifying long messages outside the exception class
(TRY003)
160-160: Avoid specifying long messages outside the exception class
(TRY003)
162-162: Avoid specifying long messages outside the exception class
(TRY003)
189-189: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
198-198: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
204-204: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
210-210: Avoid specifying long messages outside the exception class
(TRY003)
235-235: Avoid specifying long messages outside the exception class
(TRY003)
294-294: Avoid specifying long messages outside the exception class
(TRY003)
330-333: Avoid specifying long messages outside the exception class
(TRY003)
336-339: Avoid specifying long messages outside the exception class
(TRY003)
344-344: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🔇 Additional comments (3)
monai/losses/cldice.py (3)
14-23: LGTM!Imports are appropriate for the new functionality.
212-225: LGTM!clDice computation is correct. The
smoothvalidation in__init__addresses the NaN concern from previous reviews, and the1e-8epsilon provides additional safety for the harmonic mean.
318-352: LGTM!Forward method properly validates dimensions/channels, centralizes one-hot conversion, and combines losses correctly.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
ericspod
left a comment
There was a problem hiding this comment.
Thanks @aymuos15 for this enhancement, I had a few minor comments. I haven't run the code so can't compare results against the original versions of these loss functions, if you haven't done so yourself please do as a sanity check. The added tests are more thorough than the originals which is great.
| self.softmax = softmax | ||
| self.other_act = other_act | ||
|
|
||
| def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
y_true and y_pred will have to be marked as deprecated arguments here, in case anyone has used them by name.
monai/losses/cldice.py
Outdated
| torch.sum(skel_true, dim=reduce_axis) + self.smooth | ||
| ) | ||
| cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens) | ||
| cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-8) |
There was a problem hiding this comment.
Instead of hard-coding 1e-8, the class could have members for smoothing the numerator and denominator like DiceLoss.
| self.alpha = alpha | ||
| self.to_onehot_y = to_onehot_y | ||
|
|
||
| def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Have addressed all the above, thank you!
Just on this, the previous version was doing a global mean for the batch stats, but the current change is following the per sample MONAI convention. I missed this initialy. |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@monai/losses/cldice.py`:
- Around line 263-329: Add validation in the __init__ of this loss to ensure the
alpha parameter is within [0, 1]; if alpha is not between 0 and 1 (inclusive)
raise a ValueError with a clear message. Place the check before assigning
self.alpha in the constructor shown (the __init__ that constructs self.dice and
self.cldice and sets self.alpha/self.to_onehot_y) so invalid values are rejected
early.
- Around line 120-166: The constructor currently only checks iter_ >= 0 but
allows non-integer values which will later break range() in soft_skel; in
__init__ (method __init__) validate iter_ is an integer or an integral float
(e.g., 3.0) and convert to int early: if isinstance(iter_, int) use it, elif
isinstance(iter_, float) and iter_.is_integer() cast to int, else raise
TypeError; perform this check/cast before the existing non-negative check and
then assign to self.iter so soft_skel receives a safe integer.
🧹 Nitpick comments (5)
monai/losses/cldice.py (3)
182-191: Add Returns section to the forward docstring.
Required by the project docstring guidelines.📝 Example
Raises: AssertionError: When input and target (after one hot transform if set) have different shapes. + Returns: + Loss value. Scalar for "mean"/"sum", or per-batch tensor for "none".
304-306: Set base_Loss.reductionfor consistency.
Right now the module reports the default reduction even if a different one is passed.♻️ Suggested tweak
- super().__init__() + super().__init__(reduction=LossReduction(reduction).value)
337-346: Add Returns section to the forward docstring.
Required by the project docstring guidelines.📝 Example
Raises: ValueError: When number of dimensions for input and target are different. ValueError: When number of channels for target is neither 1 nor the same as input. + Returns: + Loss value. Scalar for "mean"/"sum", or per-batch tensor for "none".tests/losses/test_cldice_loss.py (2)
74-109: Add docstrings for new test definitions.
Project guidelines require Google-style docstrings for each definition.
111-135: Add ato_onehot_y=Truecombined-loss test.
This new option inSoftDiceclDiceLossisn’t exercised yet.
- Replace `smooth` parameter with `smooth_nr`/`smooth_dr` to match DiceLoss API - Add validation for `iter_` parameter (must be non-negative) - Add `@deprecated_arg` decorators for backward compatibility with `y_true`/`y_pred` - Add `stacklevel=2` to all warnings for proper caller location Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
879b182 to
caf39ef
Compare
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
225766b to
568ec36
Compare
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (5)
tests/losses/test_cldice_loss.py (3)
101-107: No test exercises the deprecatedy_pred/y_trueargument mapping.The
@deprecated_argwrappers on bothforward()methods are untested — neither a warning assertion nor a value assertion confirms they work.🧪 Suggested addition
def test_deprecated_args(self): import warnings loss = SoftclDiceLoss() inp = ONES_2D["input"] tgt = ONES_2D["target"] with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") result = loss.forward(y_pred=inp, y_true=tgt) self.assertTrue(any("y_pred" in str(warning.message) for warning in w)) np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/losses/test_cldice_loss.py` around lines 101 - 107, Add a unit test that exercises the deprecated y_pred/y_true argument mapping on SoftclDiceLoss.forward: instantiate SoftclDiceLoss, call loss.forward(y_pred=inp, y_true=tgt) inside a warnings.catch_warnings(record=True) with warnings.simplefilter("always"), assert a deprecation warning mentions "y_pred" or "y_true", and assert the returned loss value equals the expected numeric result (e.g., 0.0 for ONES_2D input) using numpy/testing or torch assertions; this verifies the `@deprecated_arg` wrapper on SoftclDiceLoss.forward actually emits the warning and returns the correct value.
81-86: Useself.skipTest()instead of barereturnin CUDA guards.
returnmakes the test appear as passed in reports;self.skipTest()correctly marks it as skipped.♻️ Proposed fix (same pattern for both classes)
def test_cuda(self): if not torch.cuda.is_available(): - return + self.skipTest("CUDA not available")Also applies to: 117-122
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/losses/test_cldice_loss.py` around lines 81 - 86, The test_cuda methods currently use a bare return when CUDA is unavailable; replace that with self.skipTest("CUDA not available") so the test is marked skipped instead of passed. Update the test method named test_cuda (the one that constructs SoftclDiceLoss and calls loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda())) and the analogous test_cuda in the other test class (the second occurrence around the ONES_2D usage) to call self.skipTest("CUDA not available") at the top of the method when torch.cuda.is_available() is False.
66-71:COMBINED_CASEShas noto_onehot_y=Truetest forSoftDiceclDiceLoss.The
to_onehot_ypath inSoftDiceclDiceLoss.forward()(Lines 353–358 of cldice.py) is entirely unexercised. Add at least one case, e.g.,alpha=0.5, single-channelB1H[WD]target with multi-channel input.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/losses/test_cldice_loss.py` around lines 66 - 71, COMBINED_CASES currently lacks a test exercising the to_onehot_y=True branch in SoftDiceclDiceLoss.forward(); add a new case to COMBINED_CASES that uses alpha=0.5 (or similar), sets to_onehot_y=True, and supplies a single-channel target (e.g., shape B1HWD or B1HW) with a corresponding multi-channel prediction so the loss converts target to one-hot; this will ensure the to_onehot_y path in SoftDiceclDiceLoss.forward() is executed during tests.monai/losses/cldice.py (2)
306-315:smooth_nr=1.0/smooth_dr=1.0passed toDiceLossdiffer fromDiceLoss's own defaults (1e-5).Users swapping from standalone
DiceLosstoSoftDiceclDiceLosswill see a different Dice component numerically. This is intentional (clDice convention), but the docstring is silent on it. Consider adding a note like: "Note:smooth_nrandsmooth_drdefault to1.0(clDice convention), which differs from standaloneDiceLossdefaults of1e-5."🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/losses/cldice.py` around lines 306 - 315, The Dice component of SoftDiceclDiceLoss passes smooth_nr=1.0 and smooth_dr=1.0 into DiceLoss (see the self.dice = DiceLoss(...) call), which differs from DiceLoss's defaults of 1e-5; update the SoftDiceclDiceLoss class/function docstring to include a short note stating that smooth_nr and smooth_dr default to 1.0 following the clDice convention (and that this differs from standalone DiceLoss defaults of 1e-5) so users are aware of the numerical difference when switching between losses.
190-218:SoftclDiceLoss.forward()skips dimension/channel pre-validation.
SoftDiceclDiceLoss.forward()raises a clearValueErrorfor dim/channel mismatches (Lines 343–351), butSoftclDiceLoss.forward()only catches shape mismatches after all preprocessing via an opaqueAssertionError. Standalone callers passing a single-channel target against a multi-channel input withto_onehot_y=Falseget no actionable message.♻️ Proposed fix
+ if input.dim() != target.dim(): + raise ValueError( + f"input and target must have the same number of dimensions, " + f"got {input.shape} and {target.shape}." + ) + if target.shape[1] != 1 and target.shape[1] != input.shape[1]: + raise ValueError( + f"number of target channels is neither 1 nor equal to input channels, " + f"got {input.shape} and {target.shape}." + ) n_pred_ch = input.shape[1]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/losses/cldice.py` around lines 190 - 218, SoftclDiceLoss.forward currently only raises an AssertionError after preprocessing, making channel/shape errors opaque; update SoftclDiceLoss.forward to perform explicit channel/shape validation early (before applying sigmoid/softmax/other_act and before to_onehot_y handling) and raise a clear ValueError like SoftDiceclDiceLoss.forward does: check n_pred_ch = input.shape[1] versus target channel dimension (or whether to_onehot_y is False and target has single channel) and verify compatibility (e.g., target.shape == input.shape or target channels == 1 when to_onehot_y=True), and replace the late AssertionError(f"ground truth has different shape...") with a ValueError containing an actionable message mentioning the function name and the mismatched shapes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/losses/cldice.py`:
- Around line 153-157: The docstring's Raises section for the cldice loss is
missing the exceptions related to the iter_ parameter; update the Raises block
in the cldice docstring (the docstring for the clDice loss class/function that
accepts the iter_ parameter) to include a TypeError when iter_ is not an int and
a ValueError when iter_ is negative, matching the existing phrasing/style used
for other exceptions (i.e., list "TypeError: When ``iter_`` is not an ``int``."
and "ValueError: When ``iter_`` is negative.").
- Around line 297-302: Add a ValueError entry to the existing Raises section of
the docstring in monai/losses/cldice.py to document that the parameter alpha
must be within [0, 1]; specifically, add a line like "ValueError: When ``alpha``
is not in ``[0, 1]``." to the Raises block for the callable that accepts the
alpha parameter (the function/class docstring where alpha is validated).
- Around line 159-170: Validate the smoothing denominator to prevent
division-by-zero: in the constructor where iter_, smooth_nr and smooth_dr are
handled (the __init__ that sets self.iter and self.smooth_dr), ensure smooth_dr
is a numeric (coerce to float) and strictly greater than 0 (raise ValueError if
<= 0 or not convertible), then assign self.smooth_dr = float(smooth_dr); this
prevents torch.sum(...)+self.smooth_dr from becoming zero when iter_=0 and
inputs are all zeros.
---
Nitpick comments:
In `@monai/losses/cldice.py`:
- Around line 306-315: The Dice component of SoftDiceclDiceLoss passes
smooth_nr=1.0 and smooth_dr=1.0 into DiceLoss (see the self.dice = DiceLoss(...)
call), which differs from DiceLoss's defaults of 1e-5; update the
SoftDiceclDiceLoss class/function docstring to include a short note stating that
smooth_nr and smooth_dr default to 1.0 following the clDice convention (and that
this differs from standalone DiceLoss defaults of 1e-5) so users are aware of
the numerical difference when switching between losses.
- Around line 190-218: SoftclDiceLoss.forward currently only raises an
AssertionError after preprocessing, making channel/shape errors opaque; update
SoftclDiceLoss.forward to perform explicit channel/shape validation early
(before applying sigmoid/softmax/other_act and before to_onehot_y handling) and
raise a clear ValueError like SoftDiceclDiceLoss.forward does: check n_pred_ch =
input.shape[1] versus target channel dimension (or whether to_onehot_y is False
and target has single channel) and verify compatibility (e.g., target.shape ==
input.shape or target channels == 1 when to_onehot_y=True), and replace the late
AssertionError(f"ground truth has different shape...") with a ValueError
containing an actionable message mentioning the function name and the mismatched
shapes.
In `@tests/losses/test_cldice_loss.py`:
- Around line 101-107: Add a unit test that exercises the deprecated
y_pred/y_true argument mapping on SoftclDiceLoss.forward: instantiate
SoftclDiceLoss, call loss.forward(y_pred=inp, y_true=tgt) inside a
warnings.catch_warnings(record=True) with warnings.simplefilter("always"),
assert a deprecation warning mentions "y_pred" or "y_true", and assert the
returned loss value equals the expected numeric result (e.g., 0.0 for ONES_2D
input) using numpy/testing or torch assertions; this verifies the
`@deprecated_arg` wrapper on SoftclDiceLoss.forward actually emits the warning and
returns the correct value.
- Around line 81-86: The test_cuda methods currently use a bare return when CUDA
is unavailable; replace that with self.skipTest("CUDA not available") so the
test is marked skipped instead of passed. Update the test method named test_cuda
(the one that constructs SoftclDiceLoss and calls loss(ONES_2D["input"].cuda(),
ONES_2D["target"].cuda())) and the analogous test_cuda in the other test class
(the second occurrence around the ONES_2D usage) to call self.skipTest("CUDA not
available") at the top of the method when torch.cuda.is_available() is False.
- Around line 66-71: COMBINED_CASES currently lacks a test exercising the
to_onehot_y=True branch in SoftDiceclDiceLoss.forward(); add a new case to
COMBINED_CASES that uses alpha=0.5 (or similar), sets to_onehot_y=True, and
supplies a single-channel target (e.g., shape B1HWD or B1HW) with a
corresponding multi-channel prediction so the loss converts target to one-hot;
this will ensure the to_onehot_y path in SoftDiceclDiceLoss.forward() is
executed during tests.
| Raises: | ||
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | ||
| ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. | ||
| Incompatible values. | ||
|
|
There was a problem hiding this comment.
Raises section omits iter_-related exceptions.
Both the TypeError (non-integer iter_) and ValueError (negative iter_) raised at Lines 165–167 are absent from the docstring.
📝 Proposed fix
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
+ TypeError: When ``iter_`` is not an integer.
+ ValueError: When ``iter_`` is negative.📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Raises: | |
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | |
| ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. | |
| Incompatible values. | |
| Raises: | |
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | |
| ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. | |
| Incompatible values. | |
| TypeError: When ``iter_`` is not an integer. | |
| ValueError: When ``iter_`` is negative. |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/losses/cldice.py` around lines 153 - 157, The docstring's Raises
section for the cldice loss is missing the exceptions related to the iter_
parameter; update the Raises block in the cldice docstring (the docstring for
the clDice loss class/function that accepts the iter_ parameter) to include a
TypeError when iter_ is not an int and a ValueError when iter_ is negative,
matching the existing phrasing/style used for other exceptions (i.e., list
"TypeError: When ``iter_`` is not an ``int``." and "ValueError: When ``iter_``
is negative.").
| super().__init__(reduction=LossReduction(reduction).value) | ||
| if other_act is not None and not callable(other_act): | ||
| raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") | ||
| if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: | ||
| raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") | ||
| if not isinstance(iter_, int): | ||
| raise TypeError(f"iter_ must be an integer but got {type(iter_).__name__}.") | ||
| if iter_ < 0: | ||
| raise ValueError(f"iter_ must be a non-negative integer but got {iter_}.") | ||
| self.iter = iter_ | ||
| self.smooth = smooth | ||
| self.smooth_nr = float(smooth_nr) | ||
| self.smooth_dr = float(smooth_dr) |
There was a problem hiding this comment.
smooth_dr=0 still produces NaN; no validation added.
With iter_=0 and an all-zero input, skel_pred is all-zero, so torch.sum(skel_pred, dim=reduce_axis) + self.smooth_dr becomes 0 when smooth_dr=0.0, causing NaN. The defaults of 1.0 guard normal usage, but an explicit smooth_dr=0.0 slips through silently.
🛡️ Proposed fix
+ if smooth_dr <= 0:
+ raise ValueError(f"smooth_dr must be positive but got {smooth_dr}.")
self.iter = iter_
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)🧰 Tools
🪛 Ruff (0.15.1)
[warning] 161-161: Avoid specifying long messages outside the exception class
(TRY003)
[warning] 163-163: Avoid specifying long messages outside the exception class
(TRY003)
[warning] 165-165: Avoid specifying long messages outside the exception class
(TRY003)
[warning] 167-167: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/losses/cldice.py` around lines 159 - 170, Validate the smoothing
denominator to prevent division-by-zero: in the constructor where iter_,
smooth_nr and smooth_dr are handled (the __init__ that sets self.iter and
self.smooth_dr), ensure smooth_dr is a numeric (coerce to float) and strictly
greater than 0 (raise ValueError if <= 0 or not convertible), then assign
self.smooth_dr = float(smooth_dr); this prevents torch.sum(...)+self.smooth_dr
from becoming zero when iter_=0 and inputs are all zeros.
| Raises: | ||
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | ||
| ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. | ||
| Incompatible values. | ||
|
|
||
| """ |
There was a problem hiding this comment.
Raises section missing ValueError for alpha out of [0, 1].
Line 305 raises it, but the docstring doesn't document it.
📝 Proposed fix
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
Incompatible values.
+ ValueError: When ``alpha`` is not in ``[0, 1]``.📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Raises: | |
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | |
| ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. | |
| Incompatible values. | |
| """ | |
| Raises: | |
| TypeError: When ``other_act`` is not an ``Optional[Callable]``. | |
| ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. | |
| Incompatible values. | |
| ValueError: When ``alpha`` is not in ``[0, 1]``. | |
| """ |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/losses/cldice.py` around lines 297 - 302, Add a ValueError entry to the
existing Raises section of the docstring in monai/losses/cldice.py to document
that the parameter alpha must be within [0, 1]; specifically, add a line like
"ValueError: When ``alpha`` is not in ``[0, 1]``." to the Raises block for the
callable that accepts the alpha parameter (the function/class docstring where
alpha is validated).
Summary
include_background,to_onehot_y,sigmoid,softmax,other_act, andreductionparameters toSoftclDiceLossandSoftDiceclDiceLossforward()to match MONAI convention (input,targetinstead ofy_true,y_pred)DiceLossFixes #8239
Changes
These changes make the clDice losses consistent with the
DiceLossAPI, addressing the issues reported in #8239 where users encountered zero loss due to missing preprocessing options.Checklist
./runtests.sh --codeformat)y_true, y_predtoinput, target)Test plan